{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    },
    "colab": {
      "name": "charsiu_demo.ipynb",
      "provenance": []
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "0zxKOeyTROc2"
      },
      "source": [
        "# PyTorch stack, Hugging Face datasets/transformers, and G2P / TextGrid / audio helpers\n",
        "!pip install torch torchaudio torchvision\n",
        "!pip install transformers datasets\n",
        "!pip install librosa praatio g2p_en"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DIROcsj7Rv4g",
        "outputId": "4cda97cf-0ed3-4027-cf17-800c4d0816d3"
      },
      "source": [
        "import os\n",
        "import shutil\n",
        "from os.path import exists, join, expanduser\n",
        "\n",
        "# Always start from a fresh clone of the `development` branch in the home\n",
        "# directory, so re-running this cell reproduces a first run.\n",
        "os.chdir(expanduser(\"~\"))\n",
        "charsiu_dir = 'charsiu'\n",
        "if exists(charsiu_dir):\n",
        "  # Remove the same relative path we check and clone into, so this works\n",
        "  # for any home directory (not only /root).\n",
        "  shutil.rmtree(charsiu_dir)\n",
        "! git clone -b development https://github.com/lingjzhu/$charsiu_dir\n",
        "\n",
        "os.chdir(charsiu_dir)"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'charsiu'...\n",
            "remote: Enumerating objects: 308, done.\u001b[K\n",
            "remote: Counting objects: 100% (308/308), done.\u001b[K\n",
            "remote: Compressing objects: 100% (254/254), done.\u001b[K\n",
            "remote: Total 308 (delta 149), reused 148 (delta 48), pack-reused 0\u001b[K\n",
            "Receiving objects: 100% (308/308), 508.52 KiB | 4.99 MiB/s, done.\n",
            "Resolving deltas: 100% (149/149), done.\n",
            "Your branch is up to date with 'origin/development'.\n",
            "/root\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GmHNb4OxRVD8",
        "outputId": "504264d8-5abe-4ba2-a56f-6db5fe5456ba",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "import sys\n",
        "\n",
        "import torch\n",
        "from datasets import load_dataset\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# The Charsiu module lives in the cloned repository's src/ folder, resolved\n",
        "# relative to the current directory (the repository root after the clone cell).\n",
        "sys.path.append('src/')\n",
        "from Charsiu import charsiu_forced_aligner, charsiu_attention_aligner"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[nltk_data] Downloading package averaged_perceptron_tagger to\n",
            "[nltk_data]     /root/nltk_data...\n",
            "[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.\n",
            "[nltk_data] Downloading package cmudict to /root/nltk_data...\n",
            "[nltk_data]   Unzipping corpora/cmudict.zip.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "N2wZBRx_WOfv"
      },
      "source": [
        "# Load the TIMIT corpus through the Hugging Face `datasets` library\n",
        "timit = load_dataset(path='timit_asr')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kBzpi5mSjiyL",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bb57f11a-5ebd-45c5-f593-89253b77151a"
      },
      "source": [
        "# load data: the first training utterance, its transcript and its audio file\n",
        "sample = timit['train'][0]\n",
        "text, audio_path = sample['text'], sample['file']\n",
        "print(f'Text transcription:{text}')\n",
        "print(f'Audio path: {audio_path}')"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Text transcription:Would such an act of refusal be useful?\n",
            "Audio path: /root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "q7paWfYdROc5",
        "outputId": "d494e91a-863d-4ca5-f9bb-e8c3160d3543"
      },
      "source": [
        "# initialize model: the neural forced aligner 'charsiu/en_w2v2_fc_10ms'\n",
        "charsiu = charsiu_forced_aligner(\n",
        "    aligner='charsiu/en_w2v2_fc_10ms',\n",
        ")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
            "/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py:341: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n",
            "  \"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 \"\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zYUK0SUKxtt_"
      },
      "source": [
        "## Forced alignment with a neural forced-alignment model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yHW92QgDROc4"
      },
      "source": [
        "# align the transcript against the audio file\n",
        "alignment = charsiu.align(text=text, audio=audio_path)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yCmbdfpzXrQ3",
        "outputId": "e2017f28-4fc1-4394-e21e-47771a59d740"
      },
      "source": [
        "print(alignment)\n",
        "print('\\n Ground Truth \\n')\n",
        "# phonetic_detail start/stop look like sample offsets at 16 kHz; convert to seconds\n",
        "print([\n",
        "    (start / 16000, stop / 16000, phone)\n",
        "    for start, stop, phone in zip(\n",
        "        sample['phonetic_detail']['start'],\n",
        "        sample['phonetic_detail']['stop'],\n",
        "        sample['phonetic_detail']['utterance'],\n",
        "    )\n",
        "])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[(0.0, 0.08, '[SIL]'), (0.08, 0.15, 'W'), (0.15, 0.19, 'UH'), (0.19, 0.24, 'D'), (0.24, 0.38, 'S'), (0.38, 0.46, 'AH'), (0.46, 0.58, 'CH'), (0.58, 0.6, 'AE'), (0.6, 0.68, 'N'), (0.68, 0.82, 'AE'), (0.82, 0.93, 'K'), (0.93, 0.99, 'T'), (0.99, 1.04, 'AH'), (1.04, 1.13, 'V'), (1.13, 1.17, 'R'), (1.17, 1.22, 'AH'), (1.22, 1.33, 'F'), (1.33, 1.41, 'Y'), (1.41, 1.48, 'UW'), (1.48, 1.53, 'Z'), (1.53, 1.62, 'AH'), (1.62, 1.68, 'L'), (1.68, 1.78, 'B'), (1.78, 1.88, 'IY'), (1.88, 1.99, 'Y'), (1.99, 2.08, 'UW'), (2.08, 2.13, 'S'), (2.13, 2.23, 'F'), (2.23, 2.27, 'AH'), (2.27, 2.47, 'L'), (2.47, 2.48, '[SIL]')]\n",
            "\n",
            " Ground Truth \n",
            "\n",
            "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gJkF2z91ROc5",
        "outputId": "1b2c838d-b794-489e-aacd-7cccb149fb81"
      },
      "source": [
        "# save the alignment to a TextGrid file\n",
        "charsiu.serve(\n",
        "    audio=audio_path,\n",
        "    text=text,\n",
        "    save_to='./local/sample.TextGrid',\n",
        ")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Alignment output has been saved to ./local/sample.TextGrid\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "swtLnWRlvTdr"
      },
      "source": [
        "## Forced alignment with an attention-based alignment model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tHhEVix-ugEU",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "882770bc-cc54-456f-de6f-d8cbf71bc25c"
      },
      "source": [
        "# load data: the same first training utterance used in the section above\n",
        "sample = timit['train'][0]\n",
        "text, audio_path = sample['text'], sample['file']\n",
        "print(f'Text transcription:{text}')\n",
        "print(f'Audio path: {audio_path}')"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Text transcription:Would such an act of refusal be useful?\n",
            "Audio path: /root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8mKL4RzfuP2a"
      },
      "source": [
        "# initialize model: the attention aligner 'charsiu/en_w2v2_fs_10ms'\n",
        "# (rebinds `charsiu`, replacing the forced aligner loaded earlier)\n",
        "charsiu = charsiu_attention_aligner(\n",
        "    'charsiu/en_w2v2_fs_10ms'\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m-SpQfeIvnBu"
      },
      "source": [
        "# align the transcript against the audio file with the attention aligner\n",
        "alignment = charsiu.align(text=text, audio=audio_path)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lbbV0gpuvJW2",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3dbf4bd3-b36a-44b8-c810-5cfc579958df"
      },
      "source": [
        "print(alignment)\n",
        "print('\\n Ground Truth \\n')\n",
        "# phonetic_detail start/stop look like sample offsets at 16 kHz; convert to seconds\n",
        "print([\n",
        "    (start / 16000, stop / 16000, phone)\n",
        "    for start, stop, phone in zip(\n",
        "        sample['phonetic_detail']['start'],\n",
        "        sample['phonetic_detail']['stop'],\n",
        "        sample['phonetic_detail']['utterance'],\n",
        "    )\n",
        "])"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[(0.0, 0.11, '[SIL]'), (0.11, 0.15, 'W'), (0.15, 0.2, 'UH'), (0.2, 0.27, 'D'), (0.27, 0.38, 'S'), (0.38, 0.5, 'AH'), (0.5, 0.58, 'CH'), (0.58, 0.63, 'AE'), (0.63, 0.69, 'N'), (0.69, 0.83, 'AE'), (0.83, 0.94, 'K'), (0.94, 1.0, 'T'), (1.0, 1.05, 'AH'), (1.05, 1.12, 'V'), (1.12, 1.18, 'R'), (1.18, 1.24, 'AH'), (1.24, 1.34, 'F'), (1.34, 1.43, 'Y'), (1.43, 1.5, 'UW'), (1.5, 1.58, 'Z'), (1.58, 1.64, 'AH'), (1.64, 1.73, 'L'), (1.73, 1.79, 'B'), (1.79, 1.9, 'IY'), (1.9, 2.01, 'Y'), (2.01, 2.08, 'UW'), (2.08, 2.17, 'S'), (2.17, 2.24, 'F'), (2.24, 2.31, 'AH'), (2.31, 2.41, 'L'), (2.41, 2.48, '[SIL]')]\n",
            "\n",
            " Ground Truth \n",
            "\n",
            "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V_Y9eH5KlO4r",
        "outputId": "35419b67-3f65-4b35-8d17-6ea3fb458436",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# Save to a separate file so the forced-aligner TextGrid written earlier\n",
        "# (./local/sample.TextGrid) is not overwritten by this model's output.\n",
        "charsiu.serve(audio=audio_path, text=text, save_to='./local/sample_attention.TextGrid')"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:984: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
            "  return (input_length - kernel_size) // stride + 1\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Alignment output has been saved to ./local/sample.TextGrid\n"
          ]
        }
      ]
    }
  ]
}